/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "transformer_utils.h"
#include "../stub/transformer.h"
#include "attr_utils.h"
namespace ge {
// Caches the number of input and output tensor descriptors of the wrapped op
// description in in_num_ / out_num_.
// Returns false, leaving both counts untouched, when op_desc_ is null;
// returns true otherwise.
//
// NOTE(review): the per-index saved-info containers are NOT sized here. The
// original initialisation was commented out:
//   map_format_in_.resize(in_num_, FORMAT_RESERVED);
//   map_ori_format_in_.resize(in_num_, FORMAT_RESERVED);
//   map_dtype_in_.resize(in_num_, DT_UNDEFINED);
//   map_format_out_.resize(out_num_, FORMAT_RESERVED);
//   map_ori_format_out_.resize(out_num_, FORMAT_RESERVED);
//   map_dtype_out_.resize(out_num_, DT_UNDEFINED);
// CatchFormatAndShape()/UpdateFormatAndShape() index these with [i] for every
// i < in_num_/out_num_ and treat FORMAT_RESERVED as "nothing saved". Confirm in
// the header that unwritten slots are valid and read back as FORMAT_RESERVED.
bool NodeShapeTransUtils::Init() {
  const bool has_op_desc = (op_desc_ != nullptr);
  if (has_op_desc) {
    in_num_ = op_desc_->GetInputDescCount();
    out_num_ = op_desc_->GetOutputDescCount();
  }
  return has_op_desc;
}
// Temporarily switches every input/output tensor whose format differs from its
// origin format back to the origin format. For each such tensor it saves, by
// index, the current format, the origin format and the data type, so that
// UpdateFormatAndShape() can restore them later.
//   - Inputs also get their shape reset to the origin shape.
//   - Outputs only have their format switched; their shape is not touched here.
// Null tensor descriptors, and tensors whose format already equals the origin
// format, are skipped. Their saved-info slots are not written.
// NOTE(review): assumes Init() succeeded and that the map_*_ containers accept
// every index i < in_num_/out_num_ — confirm against the header.
// Always returns true.
bool NodeShapeTransUtils::CatchFormatAndShape() {
  for (size_t i = 0; i < in_num_; i++) {
    auto tensor_desc_input = op_desc_->MutableInputDesc(static_cast<int>(i));
    if (tensor_desc_input == nullptr) {
      continue;
    }
    const auto format = tensor_desc_input->GetFormat();
    const auto ori_format = tensor_desc_input->GetOriginFormat();
    if (format == ori_format) {
      continue;
    }
    map_format_in_[i] = format;
    map_ori_format_in_[i] = ori_format;
    map_dtype_in_[i] = tensor_desc_input->GetDataType();
    tensor_desc_input->SetFormat(ori_format);
    tensor_desc_input->SetShape(tensor_desc_input->GetOriginShape());
  }

  for (size_t i = 0; i < out_num_; i++) {
    auto tensor_desc_output = op_desc_->MutableOutputDesc(static_cast<int>(i));
    if (tensor_desc_output == nullptr) {
      continue;
    }
    const auto format = tensor_desc_output->GetFormat();
    const auto ori_format = tensor_desc_output->GetOriginFormat();
    if (format == ori_format) {
      continue;
    }
    map_format_out_[i] = format;
    map_ori_format_out_[i] = ori_format;
    map_dtype_out_[i] = tensor_desc_output->GetDataType();
    // Only the format is switched. The output shape is presumably recomputed
    // (e.g. by shape inference) before UpdateFormatAndShape() reads it — TODO confirm.
    tensor_desc_output->SetFormat(ori_format);
  }

  return true;
}

// Restores the formats saved by CatchFormatAndShape() and converts the tensors'
// current shapes into those formats.
//
// Inputs:
//   - If nothing was saved for the slot (saved format is FORMAT_RESERVED), the
//     current format and shape are copied into the origin format and shape.
//   - Otherwise, unless the saved format equals the current one or is
//     FORMAT_ND, the current shape is expanded with
//     transformer::ExpandDimension() and transformed into the saved format.
//     The saved format and the new shape are then written back.
//
// Outputs:
//   - The same "nothing saved" handling applies.
//   - Otherwise the current format must still equal the saved origin format.
//     The current shape is recorded as the origin shape, and (unless the saved
//     format equals it or is FORMAT_ND) the format is restored and the shape
//     transformed into it.
//
// Returns false if an output's format no longer matches the saved origin
// format or if ExpandDimension() fails; true otherwise. Tensors already
// processed before a failure keep their updated state.
bool NodeShapeTransUtils::UpdateFormatAndShape() {
  transformer::ShapeTransferAccordingToFormat shape_transfer;
  for (size_t i = 0; i < in_num_; i++) {
    auto tensor_desc_input = op_desc_->MutableInputDesc(static_cast<int>(i));
    if (tensor_desc_input == nullptr) {
      continue;
    }

    // Nothing saved: format and origin format were equal when caught.
    if (map_format_in_[i] == FORMAT_RESERVED) {
      tensor_desc_input->SetOriginFormat(tensor_desc_input->GetFormat());
      tensor_desc_input->SetOriginShape(tensor_desc_input->MutableShape());
      continue;
    }
    auto ori_format = tensor_desc_input->GetFormat();
    // `auto` deduces a value type, so this is a copy. If ExpandDimension()
    // modifies its last argument, the tensor's own shape is not touched until
    // SetShape() below.
    auto ori_shape = tensor_desc_input->MutableShape();
    auto curr_format = map_format_in_[i];
    if (ori_format == curr_format || curr_format == FORMAT_ND) {
      continue;
    }
    ge::DataType dtype = map_dtype_in_[i];

    // Initialised so a missing "_infer_shape_type" attribute does not leave an
    // indeterminate value. NOTE(review): 0 assumed to mean "no reshape type".
    int infer_reshape_type = 0;
    (void) AttrUtils::Get(*tensor_desc_input, "_infer_shape_type", infer_reshape_type);
    bool is_success = transformer::ExpandDimension(op_desc_->GetType(), ori_format, curr_format, i,
                                                   infer_reshape_type, ori_shape);
    if (!is_success) {
      // Must be `false`: a non-zero Status such as FAILED converts to `true`.
      return false;
    }

    Shape out_shape;
    transformer::ShapeAndFormat shape_and_format_info {ori_shape, out_shape, ori_format, curr_format, dtype,
                                                       0};
    shape_transfer.GetShapeAccordingToFormat(shape_and_format_info);
    tensor_desc_input->SetFormat(curr_format);
    tensor_desc_input->SetShape(out_shape);
  }

  for (size_t i = 0; i < out_num_; i++) {
    auto tensor_desc_output = op_desc_->MutableOutputDesc(static_cast<int>(i));
    if (tensor_desc_output == nullptr) {
      continue;
    }
    // Nothing saved: format and origin format were equal when caught.
    if (map_ori_format_out_[i] == FORMAT_RESERVED) {
      tensor_desc_output->SetOriginFormat(tensor_desc_output->GetFormat());
      tensor_desc_output->SetOriginShape(tensor_desc_output->MutableShape());
      continue;
    }
    auto ori_shape = tensor_desc_output->MutableShape();
    auto curr_format = tensor_desc_output->GetFormat();
    // The output must still be in the origin format set by CatchFormatAndShape().
    if (curr_format != map_ori_format_out_[i]) {
      return false;
    }
    tensor_desc_output->SetOriginShape(ori_shape);
    auto saved_format = map_format_out_[i];
    if (curr_format == saved_format || saved_format == FORMAT_ND) {
      continue;
    }
    tensor_desc_output->SetFormat(saved_format);
    ge::DataType dtype = tensor_desc_output->GetDataType();

    // "_infer_shape_type" is read here for ExpandDimension()
    // (original note: "FE set and Ge get for PadDimention").
    // Initialised so a missing attribute does not leave an indeterminate value.
    int infer_reshape_type = 0;
    (void) AttrUtils::Get(*tensor_desc_output, "_infer_shape_type", infer_reshape_type);
    bool is_success = transformer::ExpandDimension(op_desc_->GetType(), curr_format, saved_format, i,
                                                   infer_reshape_type, ori_shape);
    if (!is_success) {
      return false;
    }

    Shape out_dims;
    transformer::ShapeAndFormat shape_and_format_info {ori_shape, out_dims, curr_format, saved_format, dtype, 0};
    shape_transfer.GetShapeAccordingToFormat(shape_and_format_info);
    tensor_desc_output->SetShape(out_dims);
  }
  return true;
}
} // namespace ge